import os
import tqdm
import multiprocessing
import numpy as np
from IPython import embed

def get_l2_dist(classes, embeddings=None):
    """Return the pairwise Euclidean distance matrix between class embeddings.

    Args:
        classes: Sequence of class names; each must be a key of ``embeddings``.
        embeddings: Mapping from word to vector. When omitted, the module-level
            ``glove_dict`` (bound by the ``__main__`` block) is used, so
            existing ``get_l2_dist(classes)`` calls keep working.

    Returns:
        float32 array of shape ``(len(classes), len(classes))`` whose entry
        ``[i, j]`` is ``||e_i - e_j||_2``. It is symmetric with a zero diagonal.

    Raises:
        KeyError: If a class name is missing from ``embeddings``.
        NameError: If ``embeddings`` is omitted and no global ``glove_dict``
            has been defined (e.g. when this module is imported).
    """
    if embeddings is None:
        embeddings = glove_dict
    class_num = len(classes)
    if class_num == 0:
        return np.zeros((0, 0), dtype='float32')
    # One row per class. ravel() tolerates vectors stored as (dim, 1) columns.
    vectors = np.stack(
        [np.asarray(embeddings[name], dtype=np.float64).ravel() for name in classes]
    )
    d = np.empty((class_num, class_num), dtype='float32')
    # Vectorize the inner loop only: memory stays O(n * dim) instead of the
    # O(n^2 * dim) a full broadcasted difference tensor would need.
    for i, vec in enumerate(vectors):
        d[i] = np.linalg.norm(vectors - vec, axis=1)
    return d

def get_cos_dist(classes, embeddings=None):
    """Return the pairwise cosine *similarity* matrix between class embeddings.

    Despite the name, entry ``[i, j]`` is ``e_i . e_j / (||e_i|| * ||e_j||)``.
    That is a similarity in [-1, 1] (1 on the diagonal), not a distance. The
    name is kept for existing callers.

    Args:
        classes: Sequence of class names; each must be a key of ``embeddings``.
        embeddings: Mapping from word to vector of any (shared) dimension.
            When omitted, the module-level ``glove_dict`` (bound by the
            ``__main__`` block) is used, so ``get_cos_dist(classes)`` keeps
            working.

    Returns:
        float32 array of shape ``(len(classes), len(classes))``. A zero-norm
        vector yields NaN entries in its row and column, with a NumPy
        RuntimeWarning, as in the original per-pair division.

    Raises:
        KeyError: If a class name is missing from ``embeddings``.
        ValueError: If the embeddings do not all have the same length.
        NameError: If ``embeddings`` is omitted and no global ``glove_dict``
            has been defined.
    """
    if embeddings is None:
        embeddings = glove_dict
    if len(classes) == 0:
        return np.zeros((0, 0), dtype='float32')
    # Flatten each vector instead of reshaping to a hard-coded (300, 1), so any
    # embedding dimension works.
    vectors = np.stack(
        [np.asarray(embeddings[name], dtype=np.float64).ravel() for name in classes]
    )
    norms = np.linalg.norm(vectors, axis=1)
    # The Gram matrix gives every pairwise dot product at once. This avoids the
    # deprecated float() conversion of a 1x1 array used previously.
    sim = (vectors @ vectors.T) / np.outer(norms, norms)
    return sim.astype('float32')
    
if __name__ == '__main__':
    # glove_dict.npy holds a pickled {word: np.ndarray} dict built once from the
    # raw GloVe text file, where each line is a word followed by its
    # space-separated float components. A rebuild looks like:
    #
    #     glove_dict = {}
    #     with open('/data2/dingqianggang/datasets/glove/glove.42B.300d.txt') as fr_glove:
    #         for line in fr_glove:
    #             word, *values = line.strip().split(' ')
    #             glove_dict[word] = np.array(list(map(float, values)))
    #     np.save('glove_dict.npy', glove_dict)
    #
    # allow_pickle is required because the file stores a dict (a pickled object
    # array); only load files you produced yourself.
    # glove_dict must remain a module-level name: get_l2_dist / get_cos_dist
    # read it as a global.
    glove_dict = np.load('glove_dict.npy', allow_pickle=True).item()

    cifar10_classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    cifar100_classes = np.load('/data2/dingqianggang/loss/label-embedding-loss/data/cifar-100-python/meta', allow_pickle=True)['fine_label_names']
    # Names containing underscores are reduced to their last '_'-separated
    # word, presumably so each class maps to a single GloVe vocabulary token.
    cifar100_classes = [class_name.split('_')[-1] for class_name in cifar100_classes]

    # (output file, metric function, class names), computed in this order.
    jobs = (
        ('cifar10_class_l2_dist.npy', get_l2_dist, cifar10_classes),
        ('cifar10_class_cos_dist.npy', get_cos_dist, cifar10_classes),
        ('cifar100_class_l2_dist.npy', get_l2_dist, cifar100_classes),
        ('cifar100_class_cos_dist.npy', get_cos_dist, cifar100_classes),
    )
    for out_path, metric_fn, classes in jobs:
        np.save(out_path, metric_fn(classes))
    
    